#!/usr/bin/python

''' Functions to estimate window topics '''

import os
import sys
import random
import operator
import numpy as np
import text.util
import unsupervised.nmf
import unsupervised.rankings
import gensim
from gensim.corpora.dictionary import Dictionary
from gensim.models.coherencemodel import CoherenceModel
import pickle
import multiprocessing


class FitWindow:
    ''' Fit one NMF topic model per time window.

        Inputs
        ------
        texts = list of window dictionaries. The keys read by this class
                are 'X', 'terms', 'doc_ids' and 'dir_name'.
        k_list = number of topics to fit for each window, aligned
                 position by position with texts.
        maxiter = forwarded to the NMF implementation as max_iters.

        Raises
        ------
        ValueError if texts and k_list differ in length.
    '''

    def __init__(self, texts, k_list, maxiter = 200):
        self.texts = texts
        self.k_list = k_list
        self.maxiter = maxiter

        # Each window needs exactly one topic count.
        n_windows, n_ks = len(texts), len(k_list)
        if n_windows != n_ks:
            raise ValueError("The length of texts and k_list must be equal.")

    def _fit_window(self, text, k, verbose = False, keywords = 10):
        ''' Fit a k-topic NMF model to a single window.

            Inputs
            ------
            text = window dictionary (see the class docstring for keys).
            k = number of topics.
            verbose = if True, print the formatted term rankings.
            keywords = number of top terms shown when verbose is True.

            Output
            ------
            Returns the list [doc_ids, terms, term_rankings, partition,
            W, H, topic_labels]. Other code indexes this list by
            position, so the order must not change.
        '''
        model = unsupervised.nmf.SklNMF(max_iters = self.maxiter,
                                        init_strategy = "nndsvd")
        model.apply(text['X'], k)

        # Document partition produced by the fitted model
        doc_partition = model.generate_partition()

        # Labels look like "<dir_name>_01", "<dir_name>_02", ...
        labels = ["%s_%02d" % (text['dir_name'], number)
                  for number in range(1, k + 1)]

        # Map each topic's ranked term indices back to the term strings
        rankings = [[text['terms'][idx] for idx in model.rank_terms(topic)]
                    for topic in range(k)]

        if verbose:
            print(unsupervised.rankings.format_term_rankings(rankings, top=keywords))

        return [text['doc_ids'], text['terms'], rankings, doc_partition,
                model.W, model.H, labels]

    def fit(self):
        ''' Fit every window with its matching entry of k_list.

            Output
            ------
            Returns a list with one dictionary per window, holding the
            keys 'dir_name', 'k' and 'window_results' (the list returned
            by _fit_window).
        '''
        fitted = []

        for position, window in enumerate(self.texts):
            k = self.k_list[position]
            name = window['dir_name']
            print("Fitting window = %s, k = %s" % (name, k))

            fitted.append({'dir_name': name,
                           'k': k,
                           'window_results': self._fit_window(window, k)})

        return fitted

class SelectWindowTopics:
    ''' Choose the number of window topics (k) by topic coherence.

        For every candidate k, a window topic model is fitted to each
        window with FitWindow and scored with gensim's CoherenceModel.
        The per-topic scores of a window are aggregated by their median.

        Inputs
        ------
        time_slice_data = rows whose second element (row[1]) is a
                          space-separated string of tokens; used as the
                          reference texts for the coherence estimate.
        texts = list of window dictionaries handed to FitWindow.
        k_min, k_max, step = candidate topic numbers, generated with
                             range(k_min, k_max, step). Note that k_max
                             itself is NOT evaluated.
    '''

    def __init__(self, time_slice_data, texts, k_min, k_max, step = 1):
        self.time_slice_data = time_slice_data
        self.texts = texts
        self.k_min = k_min
        self.k_max = k_max
        self.step = step

    def _extract_window_keywords(self, window_topic, num_terms=10):
        ''' Returns the top n window topic keywords from a fitted
            model.

            Inputs
            ------
            window_topic = one element of the list returned by
                           FitWindow.fit().
            num_terms = the number of keywords to return.

            Output
            ------
            Returns a tuple (dir_name, keywords), where keywords is a
            list of dictionaries with topic ids ('topic') and the
            space-joined top terms ('keywords').
        '''

        # The ranked terms for each topic are stored at position 2 of
        # window_topic['window_results'].
        ranked_terms = window_topic['window_results'][2]
        keywords = [{'topic': i, 'keywords': ' '.join(terms[0:num_terms])}
                    for i, terms in enumerate(ranked_terms)]

        return (window_topic['dir_name'], keywords)

    def get_coherence(self, window_topic, num_terms=10):
        ''' Estimate the median topic coherence of one fitted window.

            Reads self.texts_for_gensim, self.dictionary,
            self.coherence_measure and self.k, all of which are set by
            select(); calling this method before select() fails with
            an AttributeError.

            Inputs
            ------
            window_topic = one element of the list returned by
                           FitWindow.fit().
            num_terms = number of top terms per topic passed to the
                        coherence model. NOTE(review): CoherenceModel
                        has its own topn cutoff (presumably 20 by
                        default), so larger values may be truncated
                        there -- confirm against the installed gensim.

            Output
            ------
            Returns a dictionary with the keys 'dir_name', 'k' and
            'median_cv'. The key is called 'median_cv' whatever
            coherence measure was requested.
        '''
        dir_name, keywords = self._extract_window_keywords(window_topic,
                                                           num_terms)

        topics = [row['keywords'].split(' ') for row in keywords]
        for topic in topics:
            print(topic)

        # Estimate coherence by topic
        co = CoherenceModel(topics=topics,
                            texts=self.texts_for_gensim,
                            dictionary=self.dictionary,
                            coherence=self.coherence_measure)

        coherence_by_topic = co.get_coherence_per_topic()

        # Aggregate via the median
        median_coherence = np.median(coherence_by_topic)

        return {'dir_name': dir_name,
                'k': self.k,
                'median_cv': median_coherence}

    def _prepare_results(self, results):
        ''' Reshape the nested coherence results into a score matrix.

            Inputs
            ------
            results = list with one entry per k; each entry is the list
                      of get_coherence() dictionaries for all windows.

            Output
            ------
            Returns (row_labels, column_labels, scores_matrix) with
            windows as rows and topic numbers (k) as columns.
        '''
        # Get column labels (k) and row labels (time segment)
        column_labels = [row[0]['k'] for row in results]
        row_labels = [row['dir_name'] for row in results[0]]

        # Extract coherence score for each window and each k
        scores = [[per_k[t]['median_cv'] for per_k in results]
                  for t in range(len(row_labels))]

        scores_matrix = np.array(scores)
        return (row_labels, column_labels, scores_matrix)

    def _save_results(self, results, save):
        ''' Pickle results to the path save, replacing any existing file. '''
        # pickle.dump writes bytes, so the file must be opened in
        # binary mode; text mode raises a TypeError on Python 3.
        with open(save, 'wb') as pfile:
            pickle.dump(results, pfile)

    def _top_k_labels(self, scores, column_labels, n = 3):
        ''' Return the labels of the n highest scores, in no particular order. '''
        idx = np.argpartition(scores, -n)[-n:]
        return [column_labels[j] for j in idx]

    def select(self, coherence_measure = 'c_v', num_terms = 20, save = None):
        ''' Get coherence using the gensim Topic Coherence pipeline.

            Inputs
            ------
            coherence_measure = name of the coherence measure passed to
                                gensim's CoherenceModel.
            num_terms = number of top terms per topic used to estimate
                        coherence.
            save = optional file path. If given, the results collected
                   so far are pickled to it after every k, so partial
                   results survive an interrupted run.

            Output
            ------
            Returns a dictionary with 'column_labels' (the k values),
            'row_labels' (the window names) and 'selection_matrix'
            (median coherence, windows x k).
        '''

        # Initialize inputs for gensim's CoherenceModel class
        print("Preparing data...")
        self.coherence_measure = coherence_measure
        self.texts_for_gensim = [row[1].split(' ') for row in self.time_slice_data]
        self.dictionary = Dictionary(self.texts_for_gensim)

        print("Running topic selection for k = %s to k = %s, using %s" % (self.k_min, self.k_max, coherence_measure))
        results = []
        for k in range(self.k_min, self.k_max, self.step):
            # Use symmetric k for windows; get_coherence reads self.k
            self.k = k
            k_list = [k] * len(self.texts)

            # Fit window topic model
            print("Fitting window topic model and estimating coherence for k = %s..." % k)
            fw = FitWindow(self.texts, k_list)
            window_topics = fw.fit()

            # Estimate coherence
            print("\nEstimating coherence score for k = %s...\n" % self.k)
            results.append([self.get_coherence(window_topic, num_terms)
                            for window_topic in window_topics])

            # Checkpoint after every k
            if save is not None:
                self._save_results(results, save)

        row_labels, column_labels, results_matrix = self._prepare_results(results)

        # Print top 3 solutions for each window
        if len(column_labels) >= 3:
            for label, scores in zip(row_labels, results_matrix):
                top = self._top_k_labels(scores, column_labels)
                print("Top k for window = %s: %s, %s, %s" % (label,
                                                             top[0],
                                                             top[1],
                                                             top[2]))
        elif row_labels:
            print("Cannot print top 3 topic numbers. Less than 3 topic solutions estimated.")

        return {'column_labels': column_labels,
                'row_labels': row_labels,
                'selection_matrix': results_matrix}

